Napokon smo stigli do (veganskog?) "mesa" radionice! Ukratko, vizualizirat ćemo kitty mrežu preko ulaza u nju koji jako aktiviraju određene njene dijelove.
pip install --quiet torch-lucent
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Note: you may need to restart the kernel to use updated packages.
from lucent.optvis.transform import pad, jitter, random_rotate, random_scale
from lucent.optvis import render, param, transform, objectives
from math import sqrt, floor
import torchvision
import numpy as np
import torch
from lucent.optvis import render, param, transform, objectives
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import warnings
warnings.filterwarnings('ignore')
import torch.nn as nn
import torch.nn.functional as F
class KittyNet(nn.Module):
    """Small 4-block CNN (conv -> relu -> avg-pool -> batchnorm) with a 4-way
    linear head. The fc layer expects 324 = 36*3*3 features, which corresponds
    to 3x640x640 inputs — presumably the dataset image size; verify against the loader.
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names must stay exactly as-is so the saved
        # state_dict checkpoints keep loading.
        self.bn0 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3, 9, 3)
        self.pool1 = nn.AvgPool2d(4, 4)
        self.conv1_bn = nn.BatchNorm2d(9)
        self.conv2 = nn.Conv2d(9, 16, 3)
        self.pool2 = nn.AvgPool2d(4, 4)
        self.conv2_bn = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 25, 3)
        self.pool3 = nn.AvgPool2d(4, 4)
        self.conv3_bn = nn.BatchNorm2d(25)
        self.conv4 = nn.Conv2d(25, 36, 3)
        self.pool4 = nn.AvgPool2d(2, 2)
        self.fc = nn.Linear(324, 4)

    def forward(self, x):
        out = self.bn0(x)
        # Three identical conv -> relu -> pool -> batchnorm stages.
        stages = (
            (self.conv1, self.pool1, self.conv1_bn),
            (self.conv2, self.pool2, self.conv2_bn),
            (self.conv3, self.pool3, self.conv3_bn),
        )
        for conv, pool, bn in stages:
            out = bn(pool(F.relu(conv(out))))
        # Final stage has no batchnorm.
        out = self.pool4(F.relu(self.conv4(out)))
        out = torch.flatten(out, 1)  # flatten all dimensions except batch
        return self.fc(out)
Učitavamo tri checkpointa: jedan na samom početku prije ikakvog treniranja, start, jedan nakon prve epohe early, i jedan na, praktički, kraju treniranja, late.
# One untrained network per checkpoint, moved to the chosen device and then
# restored from its saved state_dict.
start_net = KittyNet().to(device)
early_net = KittyNet().to(device)
late_net = KittyNet().to(device)

for _net, _ckpt in (
    (start_net, 'saved_models/kitty/epoch_0_batch_0.pth'),  # before any training
    (early_net, 'saved_models/kitty/epoch_1_batch_0.pth'),  # after epoch 1
    (late_net, 'saved_models/kitty/epoch_7_batch_0.pth'),   # near end of training
):
    _net.load_state_dict(torch.load(_ckpt, map_location=device))
<All keys matched successfully>
# Put every checkpoint into inference mode (BatchNorm uses running stats).
for _net in (start_net, early_net, late_net):
    _net.to(device).eval()
KittyNet( (bn0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1): Conv2d(3, 9, kernel_size=(3, 3), stride=(1, 1)) (pool1): AvgPool2d(kernel_size=4, stride=4, padding=0) (conv1_bn): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(9, 16, kernel_size=(3, 3), stride=(1, 1)) (pool2): AvgPool2d(kernel_size=4, stride=4, padding=0) (conv2_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(16, 25, kernel_size=(3, 3), stride=(1, 1)) (pool3): AvgPool2d(kernel_size=4, stride=4, padding=0) (conv3_bn): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4): Conv2d(25, 36, kernel_size=(3, 3), stride=(1, 1)) (pool4): AvgPool2d(kernel_size=2, stride=2, padding=0) (fc): Linear(in_features=324, out_features=4, bias=True) )
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# da bilježnice budu manje:
%config InlineBackend.figure_format = 'jpg'
transform_test = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
test_data = torchvision.datasets.ImageFolder(root='dataset/test', transform=transform_test)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=1, shuffle=True, num_workers=1)
def imshow(img, transpose = True):
    """Display an image tensor with matplotlib.

    Args:
        img: image tensor, CHW layout when ``transpose`` is True,
            HWC otherwise.
        transpose: move channels last before plotting (default True).

    Bug fix: ``transpose`` was previously accepted but ignored — the image
    was always transposed. The default still reproduces the old behavior.
    """
    npimg = img.numpy()
    if transpose:
        npimg = np.transpose(npimg, (1, 2, 0))  # CHW -> HWC for imshow
    plt.imshow(npimg)
    plt.show()
# Grab one (shuffled) test image and show it.
dataiter = iter(test_loader)
# Bug fix: DataLoader iterators no longer expose a `.next()` method
# (removed along with Python 2 iterator protocol support); use the
# builtin `next()` instead.
images, labels = next(dataiter)
imshow(images[0])
def show_activations(activations_dict, layer, grid_dims):
    """Plot the first grid_dims**2 channel activation maps of `layer` as a
    tight square grid without axis decorations."""
    fmap = activations_dict[layer][0].cpu()  # first (only) image in the batch
    _, axes = plt.subplots(grid_dims, grid_dims, figsize=(19.55, 20))
    for channel, ax in zip(range(grid_dims * grid_dims), axes.flatten()):
        ax.matshow(fmap[channel, :, :], cmap='viridis')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.margins(x=0, y=0, tight=True)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
activation = {}  # latest forward-pass outputs, keyed by layer name


def get_activation(name):
    """Build a forward hook that snapshots its layer's output under `name`."""
    def hook(model, input, output):
        # detach(): keep raw values only, free of autograd history
        activation[name] = output.detach()
    return hook
# Capture conv activations of the untrained network for the sampled image.
start_net.conv1.register_forward_hook(get_activation('conv1'))
start_net.conv2.register_forward_hook(get_activation('conv2'))
start_net.conv3.register_forward_hook(get_activation('conv3'))
start_net.conv4.register_forward_hook(get_activation('conv4'))
# Bug fix: `.cuda()` crashes on CPU-only machines, although `device` is
# chosen conditionally at the top of the notebook — use `.to(device)`.
output = start_net(images.to(device))
show_activations(activation, 'conv1', 3)
show_activations(activation, 'conv2', 4)
show_activations(activation, 'conv3', 5)
show_activations(activation, 'conv4', 6)
activation = {}  # reset the snapshot dict before probing the trained network


def get_activation(name):
    """Return a forward hook storing the layer's detached output as activation[name]."""
    def hook(module, inputs, out):
        activation[name] = out.detach()  # snapshot without autograd history
    return hook
# Same activation capture, now for the fully trained network.
late_net.conv1.register_forward_hook(get_activation('conv1'))
late_net.conv2.register_forward_hook(get_activation('conv2'))
late_net.conv3.register_forward_hook(get_activation('conv3'))
# NOTE(review): conv4 is hooked but never plotted below — confirm whether
# that is intentional or an omitted show_activations call.
late_net.conv4.register_forward_hook(get_activation('conv4'))
# Bug fix: use the notebook-wide `device` instead of hard-coded `.cuda()`,
# which fails on CPU-only machines.
output = late_net(images.to(device))
show_activations(activation, 'conv1', 3)
show_activations(activation, 'conv2', 4)
show_activations(activation, 'conv3', 5)
Lucent je PyTorch library nastao na Tensorflow library Lucid, kojeg su razvili ljudi iz Google Braina za circuits research. On traži input koji "maksimizira" zadani channel u konvolucijskoj mreži, odnosno time nalazi onakav input koji, u određenom smislu, taj channel traži.
def lucent_show_layer(model, layer, n_channels,
                      param_f=None, transforms=None,
                      optimizer=None, preprocess=True, image_size=128):
    """Render lucent feature visualizations for the first floor(sqrt(n))**2
    channels of `layer` in a square grid.

    Args:
        model: network to visualize (must be in eval mode on the right device).
        layer: layer name string understood by lucent, e.g. "conv2".
        n_channels: number of channels in the layer; the grid shows the first
            floor(sqrt(n_channels))**2 of them.
        param_f: lucent image parameterization factory.
        transforms: lucent input transforms.
        optimizer: optimizer factory forwarded to render_vis.
        preprocess: whether lucent applies its default preprocessing.
        image_size: side length of the rendered image.
    """
    grid = floor(sqrt(n_channels))
    _, axs = plt.subplots(grid, grid, figsize=(17.55, 18))
    axs = axs.flatten()
    for ix, ax in zip(range(grid * grid), axs):
        # Bug fix: `optimizer` was accepted but never forwarded to
        # render.render_vis (which supports an `optimizer` argument).
        img = render.render_vis(model, f"{layer}:{ix}", param_f=param_f,
                                transforms=transforms, optimizer=optimizer,
                                preprocess=preprocess, progress=False,
                                show_image=False)[0]
        img = np.reshape(img, (image_size, image_size, 3))
        ax.imshow(img)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.margins(x=0, y=0, tight=True)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
Proći ćemo kroz mrežu sloj po sloj, uspoređujući u svakom koraku razliku između rezultata za start mrežu (prije početka treniranja), early mrežu (nakon prve epohe) i late mrežu (nakon skoro punih 8 epoha).
# One visualization grid per (layer, checkpoint) pair, in the original order:
# for each conv layer, show the start, early, and late checkpoints.
for _layer, _n_channels in (('conv1', 9), ('conv2', 16), ('conv3', 25)):
    for _net in (start_net, early_net, late_net):
        lucent_show_layer(_net, _layer, _n_channels,
                          param_f=lambda: param.image(128),
                          image_size=128)
Moguće je tražiti input koji istodobno maksimizira dva ili više channela.
# Find one input that excites conv2 channels 1 and 4 simultaneously.
channel = lambda n: objectives.channel("conv2", n)
joint_obj = channel(1) + channel(4)
obj = joint_obj
_ = render.render_vis(late_net, obj, show_inline=True)
100%|██████████| 512/512 [00:07<00:00, 70.03it/s]
# Suppress conv2 channel 1 (negated term) while exciting channel 4.
channel = lambda n: objectives.channel("conv2", n)
suppressed = -channel(1)
obj = suppressed + channel(4)
_ = render.render_vis(late_net, obj, show_inline=True)
100%|██████████| 512/512 [00:07<00:00, 69.49it/s]
# Same joint-maximization idea one layer deeper: conv3 channels 9 and 21.
channel = lambda n: objectives.channel("conv3", n)
obj = channel(9) + channel(21)
_ = render.render_vis(late_net, obj, show_inline=True)
100%|██████████| 512/512 [00:07<00:00, 65.75it/s]
Ponavljamo praktički iste stvari kao i prije, samo koristeći paket captum, koji na različite načine traži iste stvari.
# Install the work-in-progress `optim-wip` branch of captum from source:
# the released captum package does not ship the optim visualization API
# used below, so any installed copy is removed first.
!pip3 uninstall --quiet captum --y
!git clone https://github.com/pytorch/captum
%cd captum
!git checkout "optim-wip"
# Editable install so the checked-out branch is what gets imported.
!pip3 install --quiet -e .
import sys
# Colab fallback: make the clone importable even if the editable install
# did not take effect in this kernel.
sys.path.append('/content/captum')
%cd ..
WARNING: Skipping captum as it is not installed. WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv fatal: destination path 'captum' already exists and is not an empty directory. /home/src/LUMEN-Interpretability/captum Already on 'optim-wip' Your branch is up to date with 'origin/optim-wip'. Obtaining file:///home/src/LUMEN-Interpretability/captum Preparing metadata (setup.py) ... done Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from captum==0.3.0) (3.3.4) Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from captum==0.3.0) (1.19.5) Requirement already satisfied: torch>=1.2 in /usr/local/lib/python3.6/dist-packages (from captum==0.3.0) (1.10.1) Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch>=1.2->captum==0.3.0) (0.8) Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.2->captum==0.3.0) (3.7.4.3) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum==0.3.0) (1.3.1) Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum==0.3.0) (8.4.0) Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum==0.3.0) (3.0.7) Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum==0.3.0) (2.8.2) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum==0.3.0) (0.11.0) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.1->matplotlib->captum==0.3.0) (1.15.0) Installing collected 
packages: captum Running setup.py develop for captum Successfully installed captum-0.3.0 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv /home/src/LUMEN-Interpretability
import captum.optim as optimviz
import torchvision
from typing import Callable, Iterable, Optional
def vis_neuron_large(
    model: torch.nn.Module, target: torch.nn.Module, channel: int
) -> torch.Tensor:
    """Optimize a 640x640 natural image that maximally activates `channel`
    of the `target` layer inside `model`, using captum's optim API.

    Adapted from:
    https://colab.research.google.com/drive/1Zv7w03hOHfBWaEDMpSR1MA4D6IpAwZln

    Returns:
        The optimized image as returned by the NaturalImage parameterization
        (NCHW tensor). Bug fix: the annotation previously claimed ``-> None``
        although the image is returned and used by the caller.
    """
    image = optimviz.images.NaturalImage((640, 640)).to(device)
    # Robustness transforms: pad, jitter, scale, rotate, jitter again, crop
    # back to the original size.
    transforms = torch.nn.Sequential(
        torch.nn.ReflectionPad2d(2),
        optimviz.transforms.RandomSpatialJitter(8),
        optimviz.transforms.RandomScale(scale=(2.15, 1.85, 2, 1.95, 2.05)),
        torchvision.transforms.RandomRotation(degrees=(-15, 15)),
        optimviz.transforms.RandomSpatialJitter(64),
        optimviz.transforms.CenterCrop((640, 640)),
    )
    loss_fn = optimviz.loss.NeuronActivation(target, channel)
    obj = optimviz.InputOptimization(model, loss_fn, image, transforms)
    # The loss history is not needed here; only the optimized image is.
    _ = obj.optimize(optimviz.optimization.n_steps(512, False))
    return image()
def visualize_layer_captum(model, layer, grid_dim):
    """Show captum feature visualizations for the first grid_dim**2 channels
    of `layer` in a tight square grid."""
    _, axes = plt.subplots(grid_dim, grid_dim, figsize=(19.55, 20))
    for channel_ix, ax in zip(range(grid_dim * grid_dim), axes.flatten()):
        rendered = vis_neuron_large(model, layer, channel_ix)
        rendered = rendered.permute(0, 2, 3, 1)  # NCHW -> NHWC for imshow
        with torch.no_grad():
            arr = rendered.cpu().numpy()
        arr = arr.reshape((640, 640, 3))
        ax.imshow(arr)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.margins(x=0, y=0, tight=True)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
# Same comparison as with lucent, now via captum: for each conv layer,
# visualize the start, early, and late checkpoints in turn.
for _layer_name, _grid in (('conv1', 3), ('conv2', 4), ('conv3', 5), ('conv4', 6)):
    for _net in (start_net, early_net, late_net):
        visualize_layer_captum(_net, getattr(_net, _layer_name), _grid)